function parameters = ComputeMAandMResWithEllipsoids(Datas, parameters, methods)
%COMPUTEMAANDMRESWITHELLIPSOIDS Grid-search the truncation parameter MA and
%the residual dimension Mres using an ellipsoid-overlap criterion.
%   For every candidate pair (MA, Mres) the prepared data are passed through
%   methods.Multi2.SepFilter. An ellipsoid is then built for each class from
%   its training data (class mean plus eigendata from mySVDreduced). The
%   criterion is the number of Class A training points lying inside the
%   ellipsoid determined by Class B plus the number of Class B training
%   points lying inside the ellipsoid determined by Class A. Each ellipsoid
%   is sized to contain the fraction parameters.multilevel.concentration
%   (default 0.95) of its own class's training points.
%
%   On exit parameters.snapshots.k1 is the MA of the first pair that attains
%   the smallest count. parameters.multilevel.Mres is chosen by setMres from
%   the Mres values that attain that count for this MA.
%
%   Returns immediately, with parameters unchanged, when both
%   parameters.snapshots.k1 and parameters.multilevel.Mres are already set,
%   or when parameters.multilevel.svmonly is 0 or 1.

% One plot colour per class.
colors = lines(2);

% Nothing to do if MA (k1) and Mres were both supplied.
if ~isempty(parameters.snapshots.k1) & ~isempty(parameters.multilevel.Mres), return, end
% NOTE(review): presumably the svmonly = 0/1 modes do not use MA/Mres -- confirm.
if ismember(parameters.multilevel.svmonly, [0,1]), return, end
% Default ellipsoid coverage fraction.
if isempty(parameters.multilevel.concentration), parameters.multilevel.concentration = 0.95; end



%% Balance Data Sets (disabled; kept for reference)
% AData = Datas.rawdata.AData; %Create Backup
% BData= Datas.rawdata.BData; %Create Backup
% 
% AData = AData - mean(AData,2); %Subtract Class A Mean
% BData = BData - mean(AData,2); %Subtract Class A Mean
% 
% 
% NA = size(Datas.rawdata.AData,2);
% NB = size(Datas.rawdata.BData,2);
% 
% ZMA = AData; %AData(:,1:NB);
% ZMB = BData - mean(BData,2);

%% Prep data
% Record sample counts: parameters.data.i and .j are set to 1, and
% parameters.data.A / .B to the number of columns of the raw Class A / B data.
Classes = 'AB';
for I = 'ij', parameters.data.(I) = 1; end
for Class = Classes, parameters.data.(Class) = size(Datas.rawdata.([Class 'Data']), 2); end
fprintf('Computing Ideal truncation MA and residual dimension Mres\n')

%% Get maximum allowable truncation parameter size
% Search MA = 2..maxMA unless k1 was supplied, in which case only that
% value is tried.
maxMA = GetMaxTrunc(parameters);
switch isempty(parameters.snapshots.k1)
    case true, MA = 2:maxMA;
    case false, MA = parameters.snapshots.k1;
end
% NinWrongEllipse(MA, Mres) holds the criterion value; pairs that are
% never evaluated stay NaN.
NinWrongEllipse = nan(maxMA, parameters.data.numofgene);


% Smallest criterion value seen so far and the pair that attained it.
Misplaced = Inf;
BestMA = [];
BestMres = [];

%% Prep Data
% Prepare once. SepFilter overwrites Datas for each (MA, Mres) pair, so this
% prepared copy is restored at the end of every inner iteration.
Datas = methods.all.prepdata(Datas, parameters);
BackupDatas = Datas;


% Figure 1 holds the optional 2-D diagnostic plots.
% NOTE(review): ax is never used.
figure(1), ax = axis; hold on, axis square
for ima = MA 
    %Truncation parameter. numofgene - ima = dim(orth subspace)
    
    parameters.snapshots.k1 = ima;
    dimOrth = parameters.data.numofgene - ima;
% Try every residual dimension Mres = 2..dimOrth for this MA.
for imres = 2:dimOrth
       
        Datas = methods.Multi2.SepFilter(Datas, parameters, methods,imres);
  
        % Plot only for MA values that are multiples of floor(maxMA/10)
        % (never when maxMA < 10), and only if the filtered data turn out
        % to be two-dimensional (refined below). cla clears the axes.
        stop = mod(ima, floor(maxMA/10)) == 0; cla
        % Build an ellipsoid membership test for each class.
        for i = 1:2

        % NOTE(review): this print sits inside the class loop, so it is
        % emitted twice per (MA, Mres) pair.
        fprintf('MA = %d, Mres = %d, \n', ima, imres);
        %if ima == 9 && imres == 3, keyboard, end

        %% Construct new covariance eigendata
        % Eigendata of this class's training covariance, with numerically
        % zero eigenpairs removed. NOTE(review): NTraining is unused.
        Class = Classes(i);
        NTraining = size(Datas.(Class).Training, 2);
        %Datas = UpdateData(S, Datas, parameters);
        %[~, Eval, ~, Evec, ~] = snapshotssub2(Datas.(Class).Training, imres);
        [Evec, Eval] = mySVDreduced(Datas.(Class).Training);
        ClassMean = mean(Datas.(Class).Training, 2);
        %[Evec, Eval] = UpdateEigendata(Eval, S, parameters);
        [Evec, Eval] = trimEigendata(Evec, Eval);
        % Keep plotting only while every class so far has exactly 2-D eigendata.
        stop = stop & length(Eval) == 2 & size(Evec,1) == 2;

        
        

        %% Find radius which captures the concentration fraction (default 95%) of the points in each class
        %radius = FindPercentileRadius(Datas.(Class).Training, ClassMean, Evec, Eval, parameters.multilevel.concentration);
        radius = FindPercentileRadius(Datas.(Class).Training, ClassMean, Evec, Eval, parameters.multilevel.concentration);

        %% Develop check and to determine if a point y lies inside an ellipse
        % The closure captures this class's mean, eigendata and radius.
        IsInEllipseInline{i} = @(Y) IsInEllipse(Y, ClassMean, Evec, Eval, radius);

        % Functions to visualize data in two-dimensions
        if stop
          %  Q = (Eval.^0.5) .* Evec; %Transforms unit circle into ellipse
            plotEllipseInline{i} = @()  plotEllipse(Evec, Eval, radius, ClassMean, gca); 
            scatterInline{i} = @() scatter(Datas.(Class).Training(1,:), ...
                                           Datas.(Class).Training(2,:), ...
                                           36, colors(i,:));
            scatterInline2{i} = @() scatter(Datas.(Class).Testing(1,:), ...
                                           Datas.(Class).Testing(2,:), ...
                                           36, colors(i,:), 'filled');
        end
        end
        

        %% Determine number of points from each class inside the wrong ellipse
        % For class i, count its training points inside the ellipsoid of
        % the other class j (j = 3 - i). When plotting, also draw class i's
        % ellipse with its training (open) and testing (filled) points.
        NinWrongEllipse(ima, imres) = 0;
        for i = 1:2  

            j = 2 - i + 1;
            Class = Classes(i);
            EllipseCheck = IsInEllipseInline{j}(Datas.(Class).Training);
            if stop
                plotEllipseInline{i}();
                scatterInline{i}();
                scatterInline2{i}();
            end  
            NinWrongEllipse(ima, imres) = NinWrongEllipse(ima, imres) + sum( EllipseCheck );
        end

        % Strict < keeps the first pair that attains the minimum.
        if NinWrongEllipse(ima, imres) < Misplaced
            Misplaced = NinWrongEllipse(ima, imres);
            BestMA = ima; BestMres = imres;
        end

        if stop
            title( sprintf('MA = %d, Misplaced = %d', ima, NinWrongEllipse(ima, imres) ))
            pause(5)
        end
        




        % Undo SepFilter before the next pair.
        Datas = BackupDatas;
end
end


% Adopt the best MA, then choose Mres among the values that tie for the
% minimum on that MA's row (NaN entries never compare equal).
parameters.snapshots.k1 = BestMA;
N = NinWrongEllipse(BestMA,:);



parameters.multilevel.Mres = setMres( find(N == Misplaced) , parameters) ;
fprintf('MA: %d\nMres: %d \n Minimum points in wrong ellipsoid:%d\n\n', BestMA, BestMres, Misplaced)


% Summary plots of the criterion over the (MA, Mres) grid.
% NOTE(review): the "close all" below closes these figures (and figure 1)
% right after they are drawn.
plotHeatMap(NinWrongEllipse)

figure('Name', 'Histogram')
histogram(NinWrongEllipse(~isnan(NinWrongEllipse)))



close all
end

%% ========================================================================
%% Auxiliary Functions: 
%% ========================================================================

%==========================================================================
function maxTrunc = GetMaxTrunc(parameters)
%GETMAXTRUNC Upper bound on the truncation parameter MA.
%   The bound is (smallest Class A training size) minus (largest Class B
%   training size), capped at parameters.data.numofgene, where
%     smallest Class A training size = A - Kfold
%     largest Class B training size  = B - max(Kfold, mod(B, Kfold))
%   with A = parameters.data.A, B = parameters.data.B.
K = parameters.Kfold;
nA = parameters.data.A;
nB = parameters.data.B;

smallestTrainA = nA - K;
largestTrainB = nB - max(K, mod(nB, K));

maxTrunc = min(parameters.data.numofgene, smallestTrainA - largestTrainB);
end
%==========================================================================

%==========================================================================
function [Evec, Eval] = mySVDfull(data)
%MYSVDFULL Decompose the full NFeatures-by-NFeatures sample covariance.
%   data : NFeatures-by-NSamples matrix, one sample per column.
%   The data are centred on their row means and scaled by
%   sqrt(1/(NSamples - 1)) before forming data*data'.
%   Evec : left singular vectors of that covariance matrix.
%   Eval : its singular values as a column vector, zero-padded (if ever
%          shorter) to one value per feature.
[nRows, nCols] = size(data);
centred = (data - mean(data, 2)) * sqrt(1/(nCols - 1));
covMat = centred * centred';
[Evec, Eval, ~] = svd(covMat, 'vector');

nMissing = nRows - length(Eval);
if nMissing > 0
    Eval = [Eval(:); zeros(nMissing, 1)];
end
end
%==========================================================================

%==========================================================================
function [Evec, Eval] = mySVDreduced(data)
%MYSVDREDUCED Eigendecomposition of the sample covariance of DATA, computed
%in whichever of feature space or sample space is smaller.
%   data : NFeatures-by-NSamples matrix, one sample per column.
%   Evec : covariance eigenvectors, one per column (as returned by
%          mySVDfull when data is wide; NSamples columns when data is tall).
%   Eval : the matching eigenvalues as a column vector.
%
%   Wide data (NFeatures <= NSamples) are delegated to mySVDfull. For tall
%   data the NSamples-by-NSamples Gram matrix of the centred data, scaled
%   by sqrt(1/(NSamples - 1)), is decomposed. Its eigenvectors are mapped
%   back to feature space, so the NFeatures-by-NFeatures covariance is
%   never formed. Nothing is printed.
NFeatures = size(data, 1); % number of features (rows)
NSamples = size(data, 2);  % number of samples (columns)

if NFeatures <= NSamples % data matrix is wide
    [Evec, Eval] = mySVDfull(data);
    return
end

% Data matrix is tall: work with the small Gram matrix.
mx = mean(data, 2);
data = data - mx;
data = data * sqrt(1/(NSamples - 1));
C = data' * data;
[~, S, V] = svd(C, 'vector');
S = S(:)';

% Centring puts the all-ones vector in the null space of C, so at least one
% S is (numerically) zero. Dividing by sqrt(S) there would give Inf/NaN or
% meaningless columns, so those eigenvector columns are left at zero. The
% eigenvalues themselves are returned unchanged.
tol = max(size(C)) * eps(max([S, 0]));
keep = S > tol;
Evec = zeros(NFeatures, numel(S), 'like', data);
Evec(:, keep) = (data * V(:, keep)) ./ sqrt(S(keep));
Eval = S';
end
%==========================================================================

%==========================================================================
function [Evec, Eval] = mysvd(data, k)
%MYSVD Leading principal directions of DATA.
%   data : NFeatures-by-NSamples matrix, one sample per column. The data
%          are centred on their row means and scaled by
%          sqrt(1/(NSamples - 1)), as in mySVDfull and mySVDreduced.
%   k    : number of leading components to return. Optional; when omitted,
%          empty or 0, all min(NFeatures, NSamples) components are returned.
%          Values above that count are clipped.
%   Evec : NFeatures-by-k matrix of principal directions.
%   Eval : k-by-1 vector of singular values of the centred, scaled data
%          (square roots of the covariance eigenvalues), in descending order.
%
%   The smaller of the two products is decomposed. Wide data use the
%   NFeatures-by-NFeatures covariance, whose eigenvectors are returned
%   directly. Tall data use the NSamples-by-NSamples Gram matrix, whose
%   eigenvectors are mapped back to feature space.
[NFeatures, NSamples] = size(data);

mx = mean(data, 2);       % row means
data = data - mx;         % centre each feature
data = data * sqrt(1/(NSamples - 1));

isWide = NFeatures <= NSamples;
if isWide
    C = data * data';     % NFeatures-by-NFeatures covariance
else
    C = data' * data;     % NSamples-by-NSamples Gram matrix
end

% svd(C, k) only accepts k = 0 (economy size), so decompose fully and
% truncate to the k leading components.
[~, lambda, V] = svd(C, 'vector');
lambda = lambda(:)';
if nargin < 2 || isempty(k) || k == 0
    k = numel(lambda);
end
k = min(k, numel(lambda));
lambda = lambda(1:k);
V = V(:, 1:k);
S = sqrt(lambda);

if isWide
    Evec = V;
else
    % Columns whose singular value is numerically zero cannot be
    % normalised; leave them at zero rather than producing Inf/NaN.
    Evec = zeros(NFeatures, k, 'like', data);
    keep = lambda > max(size(C)) * eps(max([lambda, 0]));
    Evec(:, keep) = (data * V(:, keep)) ./ S(keep);
end
Eval = S';
end
%==========================================================================



%==========================================================================
function [Evec, Eval] = trimEigendata(Evec, Eval)
%TRIMEIGENDATA Drop eigenpairs whose eigenvalue is numerically zero.
%   An eigenvalue survives only if it exceeds
%   eps * max(size(Evec)) * max(Eval). The matching column of Evec is kept
%   with it. Eval is returned as a column vector.
vals = reshape(Eval, [], 1);
threshold = eps * max(size(Evec)) * max(vals);
keep = vals > threshold;
Evec = Evec(:, keep);
Eval = vals(keep);
end
%==========================================================================

%==========================================================================
% function mysvd = mysnapshot(parameters)
% switch parameters.multilevel.eigentag
%     case 'largest', mysvd = @snapshotssub2;
%     case 'smallest', mysvd = @snapshotssub3;
% end
% end
%==========================================================================

%==========================================================================
function [Evec] = ComputeOptimalSubspace(Datas, parameters)
%COMPUTEOPTIMALSUBSPACE Compute Evec from the Class B training data expressed
%in the Class A residual eigenbasis.
%   ResidA holds the columns of Datas.A.eigenvectors after the first
%   parameters.snapshots.k1. The Class B training data are multiplied by
%   ResidA', and Evec is the fourth output of the routine returned by
%   mysnapshot(parameters), called with parameters.multilevel.Mres.
%   NOTE(review): mysnapshot is commented out in this file, so this function
%   fails unless a mysnapshot is defined elsewhere on the path. It is not
%   called from the code visible here.
ResidA = Datas.A.eigenvectors(:, (parameters.snapshots.k1 + 1):end);
%ZMBT = Datas.B.Training -  mean(Datas.B.Training,2);
%dimSubspace = parameters.multilevel.Mres;
%projCov = ResidA' * ZMBT;
projCov = ResidA' * Datas.B.Training;
% This variable shadows the local function mysvd within this function.
mysvd = mysnapshot(parameters);

[~, ~, ~, Evec] = mysvd(projCov, parameters.multilevel.Mres);
end
%==========================================================================

%==========================================================================
function Datas = UpdateData(S, Datas, parameters)
%UPDATEDATA Re-express both classes' Training and Testing sets in the basis
%ResidA * S.
%   ResidA holds the columns of Datas.A.eigenvectors after the first
%   parameters.snapshots.k1. Every set X (A and B, Training and Testing) is
%   replaced by (ResidA * S)' * X.
firstResid = parameters.snapshots.k1 + 1;
projector = Datas.A.eigenvectors(:, firstResid:end) * S;
splitNames = {'Training', 'Testing'};
for cls = 'AB'
    for s = 1:numel(splitNames)
        name = splitNames{s};
        Datas.(cls).(name) = projector' * Datas.(cls).(name);
    end
end
end
%==========================================================================

%==========================================================================
function [Evec, Eval] = UpdateEigendataA(Eval, S, parameters)
%UPDATEEIGENDATAA SVD of S with its columns weighted by sqrt(Eval).
%   Keeps size(S,2) eigenvalues: the leading ones when
%   parameters.multilevel.eigentag is 'largest', the trailing ones when it
%   is 'smallest' (any other tag keeps Eval as given). Column j of S is then
%   scaled by Eval(j)^0.5. Returns the left singular vectors (Evec) and the
%   singular values as a column vector (Eval) of that weighted matrix.
nKeep = size(S, 2);
tag = parameters.multilevel.eigentag;
if strcmp(tag, 'largest')
    Eval = Eval(1:nKeep);
elseif strcmp(tag, 'smallest')
    Eval = Eval(end-nKeep+1:end);
end
weights = reshape(Eval, 1, []).^0.5;
[Evec, Eval, ~] = svd(weights .* S, 'vector');
end
%==========================================================================

%==========================================================================
function [Evec, Eval] = UpdateEigendataB(Eval, S, parameters)
%UPDATEEIGENDATAB Pair the eigenvalues with coordinate-axis eigenvectors.
%   With n = length(Eval) and I the sparse identity of size
%   parameters.data.numofgene:
%     'largest'  -> Eval(1:n)         and I(:, 1:n)
%     'smallest' -> Eval(end-n+1:end) and I(:, end-n+1:end)
%   For a vector Eval the selection keeps every element, so Eval comes back
%   unchanged. Any other eigentag leaves Evec unassigned.
%   NOTE(review): S is not used (UpdateEigendataA takes its count from
%   size(S,2)) -- confirm this is intended.
nVals = length(Eval);
basis = speye(parameters.data.numofgene);
tag = parameters.multilevel.eigentag;
if strcmp(tag, 'largest')
    Eval = Eval(1:nVals);
    Evec = basis(:, 1:nVals);
elseif strcmp(tag, 'smallest')
    Eval = Eval(end-nVals+1:end);
    Evec = basis(:, end-nVals+1:end);
end
end
%==========================================================================

%==========================================================================
function [Evec, Eval] = UpdateEigendata(Eval, S, Class, parameters)
%UPDATEEIGENDATA Route to the class-specific eigendata update.
%   Class 'A' uses UpdateEigendataA and Class 'B' uses UpdateEigendataB.
%   For any other Class no output is assigned.
if strcmp(Class, 'A')
    [Evec, Eval] = UpdateEigendataA(Eval, S, parameters);
elseif strcmp(Class, 'B')
    [Evec, Eval] = UpdateEigendataB(Eval, S, parameters);
end
end
%==========================================================================

%==========================================================================
function radius = FindPercentileRadius(Y,center, Evec, Eval, percentile)
%FINDPERCENTILERADIUS Radius of the ellipse containing a given fraction of Y.
%   Each column y of Y is mapped to x = diag(1./Eval) * Evec' * (y - center),
%   the same map SatisfiesNorm uses. RADIUS is the PERCENTILE quantile of the
%   Euclidean norms of these x.
%   percentile : probability in [0, 1] passed to quantile (e.g. 0.95).
%   Eval may be a row or a column vector.
Eval = Eval(:); % a row Eval would otherwise broadcast into a matrix
w = Y - center;
scaled = (1 ./ Eval) .* (Evec' * w);
radius = quantile(sqrt(sum(scaled.^2, 1)), percentile);
end
%==========================================================================

%==========================================================================
function idx = IsInSpan(Y, center, Evec, relTol)
%ISINSPAN True for columns of Y - center that lie in the span of Evec.
%   idx(k) is true when w = Y(:,k) - center satisfies
%       ||Evec*(Evec'*w) - w|| <= relTol * ||w||,
%   i.e. its residual is small relative to its own length. (An absolute
%   tolerance would reject in-span points of large magnitude through
%   rounding error alone.) A point equal to center is always in the span.
%   Evec is presumably expected to have orthonormal columns, so that
%   Evec*Evec' is the orthogonal projector -- confirm for all callers.
%   relTol : optional relative tolerance, default sqrt(eps).
if nargin < 4 || isempty(relTol)
    relTol = sqrt(eps);
end
w = Y - center;
% Project without forming the n-by-n matrix Evec*Evec'.
residuals = Evec * (Evec' * w) - w;
idx = sqrt(sum(residuals.^2, 1)) <= relTol * sqrt(sum(w.^2, 1));
end
%==========================================================================

%==========================================================================
function idx = SatisfiesNorm(Y, center, Evec, Eval, radius)
%SATISFIESNORM Ellipse-norm test for the columns of Y.
%   Each column y is mapped to x = diag(1./Eval) * Evec' * (y - center).
%   idx is a logical row vector that is true where ||x|| < radius (strict).
%   Eval may be a row or a column vector.
invScale = 1 ./ reshape(Eval, [], 1);
coords = invScale .* (Evec' * (Y - center));
idx = sqrt(sum(coords .^ 2, 1)) < radius;
end
%==========================================================================

%==========================================================================
function idx = IsInEllipse(Y, center, Evec, Eval, radius)
%ISINELLIPSE True for columns of Y that pass both IsInSpan (they lie in the
%span of Evec about center) and SatisfiesNorm (they fall strictly inside
%the ellipse of the given radius).
inSpan = IsInSpan(Y, center, Evec);
inNorm = SatisfiesNorm(Y, center, Evec, Eval, radius);
idx = inSpan & inNorm;
end
%==========================================================================

%==========================================================================
function plotEllipse(Evec, Eval, r, center, ax)
%PLOTELLIPSE Draw the 2-D ellipse {center + r*Evec*diag(Eval)*u : ||u|| = 1}
%on the axes AX as a black line of width 2.
%   When Evec has orthonormal columns, this curve is the boundary of the
%   region SatisfiesNorm accepts for the same Evec, Eval and radius r.
%   Evec must have exactly two columns. Eval may be a row or a column vector.
%   Only the first two coordinates of the curve are plotted.
Q = Evec .* reshape(Eval, 1, []); % scale column k of Evec by Eval(k)
t = linspace(0, 2*pi, 500);
ellipse = r * Q * [cos(t); sin(t)] + center(:);
plot(ax, ellipse(1,:), ellipse(2,:), 'LineWidth', 2, 'Color', 'k');
end
%==========================================================================

%==========================================================================
function plotHeatMap(Data)
%PLOTHEATMAP Show Data as a scaled image in a new figure named
%'In Wrong Ellipsoid', with columns labelled Mres and rows labelled MA.
%   Uses the jet colormap with its first entry replaced by white and its
%   last entry by black, plus a colorbar.
%   NOTE(review): presumably the white first entry is meant to make the
%   lowest (and NaN) cells stand out -- confirm.
figure('Name', 'In Wrong Ellipsoid');
imagesc(Data);
cmap = jet;
cmap([1, end], :) = [1, 1, 1; 0, 0, 0];
colormap(cmap);
colorbar;
xlabel('Mres');
ylabel('MA');
end
%==========================================================================

%==========================================================================
function Y = setMres(Mres, parameters)
%SETMRES Pick up to parameters.multilevel.l representative Mres values.
%   Mres : vector of Mres values at which the separation criterion attains
%          its minimum.
%   If Mres has fewer than l entries it is returned unchanged. Otherwise:
%     1. l targets are spaced linearly from min(Mres) to max(Mres)-1.
%     2. The targets are rounded up.
%     3. Each target is replaced by its nearest entry of Mres (knnsearch).
%   In that case Y is a row vector and may contain repeated values.
nWanted = parameters.multilevel.l;
if length(Mres) < nWanted
    Y = Mres;
else
    candidates = Mres(:);
    targets = ceil(linspace(min(Mres), max(Mres) - 1, nWanted));
    nearest = knnsearch(candidates, targets(:));
    Y = reshape(candidates(nearest(:)), 1, []);
end
end
%==========================================================================

    

